{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Mujoco Control \n",
    "This tutorial introduces how to set up closed-loop control system for multi-body robotic systems in Mujoco. \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mujoco\n",
    "import mediapy as media\n",
    "import numpy as np\n",
    "\n",
    "import mujoco.viewer as viewer\n",
    "modelxml = \"\"\"\n",
    "<mujoco model=\"CartPole\">\n",
    "  <compiler eulerseq=\"XYZ\"/>\n",
    "  <default>\n",
    "    <default class=\"unused\"/>\n",
    "  </default>\n",
    "  <asset>\n",
    "    <texture name=\"grid\" type=\"2d\" builtin=\"checker\" rgb1=\"0.1 0.2 0.3\" rgb2=\"0.2 0.3 0.4\" width=\"512\" height=\"512\"/>\n",
    "    <material name=\"grid\" class=\"unused\" texture=\"grid\" texrepeat=\"1 1\" texuniform=\"true\" reflectance=\"0.2\"/>\n",
    "  </asset>\n",
    "  <worldbody>\n",
    "    <geom name=\"floor\" class=\"unused\" type=\"plane\" condim=\"3\" size=\"0 0 0.05\" material=\"grid\" pos=\"0 0 -1\"/>\n",
    "    <body name=\"Cart\" pos=\"0 0 0\" euler=\"0 -0 0\">\n",
    "    <!-- change inertia, different from sdf -->\n",
    "    <!-- For this model case, with the cart not having any rotational\n",
    "             degrees of freedom, the values of the inertia matrix do not\n",
    "             participate in the model. Therefore we just set them to zero\n",
    "             (or near to zero since sdformat does not allow exact zeroes\n",
    "             for inertia values). -->\n",
    "      <inertial pos=\"0 0 0\" mass=\"5\" diaginertia=\"0.00000000001 0.00000000001 0.00000000001\"/>\n",
    "      <geom name=\"cart_visual\" class=\"unused\" type=\"box\" contype=\"0\" conaffinity=\"0\" group=\"0\" size=\"0.12 0.06 0.06\" pos=\"0 0 0\" euler=\"0 -0 0\"/>\n",
    "      <joint name=\"CartSlider\" class=\"unused\" type=\"slide\" pos=\"0 0 0\" axis=\"1 0 0\"/>\n",
    "      <body name=\"Pole\" pos=\"0 0 -0.5\" euler=\"0 -0 0\">\n",
    "        <inertial pos=\"0 0 0\" mass=\"1\" diaginertia=\"0.00000000001 0.00000000001 0.00000000001\"/>\n",
    "        <geom name=\"pole_point_mass\" class=\"unused\" type=\"sphere\" contype=\"0\" conaffinity=\"0\" group=\"0\" size=\"0.05\" pos=\"0 0 0\" euler=\"0 -0 0\"/>\n",
    "        <geom name=\"pole_rod\" class=\"unused\" type=\"cylinder\" contype=\"0\" conaffinity=\"0\" group=\"0\" size=\"0.025 0.25\" pos=\"0 0 0.25\" euler=\"0 -0 0\"/>\n",
    "        <joint name=\"PolePin\" class=\"unused\" type=\"hinge\" pos=\"0 0 0.5\" axis=\"0 -1 0\"/>\n",
    "      </body>\n",
    "    </body>\n",
    "  </worldbody>\n",
    "  <sensor>\n",
    "        <!-- joint position sensing -->\n",
    "        <jointpos joint=\"CartSlider\" name=\"cart_p\"/>\n",
    "        <jointpos joint=\"PolePin\" name=\"pole_theta\"/>\n",
    "        <!-- joint velocity sensing -->\n",
    "        <jointvel joint=\"CartSlider\" name=\"cart_v\"/>\n",
    "        <jointvel joint=\"PolePin\" name=\"pole_w\"/>\n",
    "  </sensor>\n",
    "  <actuator>\n",
    "    <motor joint=\"CartSlider\"/>\n",
    "  </actuator>\n",
    "  <keyframe>\n",
    "    <key name=\"off1\" qpos=\"0 3.15\"/>\n",
    "  </keyframe>\n",
    "</mujoco>\n",
    "\"\"\"\n",
    "model = mujoco.MjModel.from_xml_string(modelxml)\n",
    "data = mujoco.MjData(model)\n",
    "mujoco.mj_resetDataKeyframe(model, data, 0)\n",
    "viewer.launch(model,data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Simulatioin of open-loop resposne \n",
    "now let's simulate the physics\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table class=\"show_videos\" style=\"border-spacing:0px;\"><tr><td style=\"padding:1px;\"><img width=\"320\" height=\"240\" style=\"image-rendering:auto; object-fit:cover;\" src=\"\"/></td></tr></table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "duration = 3    # (seconds)\n",
    "framerate = 30  # (Hz)\n",
    "\n",
    "# Simulate and display video.\n",
    "frames = []\n",
    "mujoco.mj_resetDataKeyframe(model, data, 0)  # Reset the state to keyframe 0\n",
    "with mujoco.Renderer(model) as renderer:\n",
    "  while data.time < duration:\n",
    "    mujoco.mj_step(model, data)\n",
    "    if len(frames) < data.time * framerate:\n",
    "      renderer.update_scene(data)\n",
    "      pixels = renderer.render()\n",
    "      frames.append(pixels)\n",
    "media.show_video(frames, fps=framerate, codec='gif')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define customized controller which returns the feedback control action\n",
    "# if we want to use control callback, we need to set data.ctrl, the returned value does not matter in this case\n",
    "def myControl(model, data):\n",
    "    x = np.hstack((data.qpos,data.qvel))\n",
    "    xref = np.array([0, np.pi, 0, 0])\n",
    "    x_error = x-xref\n",
    "    K = np.array([-0.5, 0.5, 0, 0])\n",
    "    u = K@x_error\n",
    "    data.ctrl = u\n",
    "    return u\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mujoco.viewer as viewer\n",
    "import time\n",
    "\n",
    "mujoco.mj_resetDataKeyframe(model, data, 0)  # Reset the state to keyframe 0\n",
    "\n",
    "# \n",
    "with viewer.launch_passive(model, data) as viewer:  \n",
    "  # launch_passive means all the simulation should be done by the user \n",
    "  \n",
    "  start = time.time()\n",
    "  while viewer.is_running() and time.time() - start < 10:\n",
    "    step_start = time.time()\n",
    "    data.ctrl = 0.2*myControl(model,data)\n",
    "    mujoco.mj_step(model, data)\n",
    "\n",
    "  # let viewer show updated info\n",
    "    viewer.sync()\n",
    "    \n",
    "  # #  make sure the while loop is called every sampling period \n",
    "    # computation inside the loop may take some nontrivial time. \n",
    "    time_until_next_step = model.opt.timestep - (time.time() - step_start)\n",
    "    if time_until_next_step > 0:\n",
    "      time.sleep(time_until_next_step)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Simulation method 1: \n",
    "interactive simulation, the viewer can respond to user's input during the simulation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# simulating the closed-loop system by using the control callback\n",
    "\n",
    "import mujoco.viewer as viewer\n",
    "mujoco.mj_resetDataKeyframe(model, data, 0)  # Reset the state to keyframe 0\n",
    "mujoco.set_mjcb_control(myControl)\n",
    "viewer.launch(model,data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Simuation method 2:\n",
    "Passive simulation, the user has the control over stepping the physics "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mujoco.viewer as viewer\n",
    "import time\n",
    "\n",
    "mujoco.mj_resetDataKeyframe(model, data, 0)  # Reset the state to keyframe 0\n",
    "with viewer.launch_passive(model,data) as viewer:\n",
    "    while viewer.is_running() and data.time < 5:\n",
    "        data.ctrl = 20*myControl(model,data)  # apply control\n",
    "        mujoco.mj_step(model, data)            # step xdot= f(x,u)\n",
    "        viewer.sync()\n",
    "        time.sleep(model.opt.timestep)        \n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sim",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
